import argparse
import os


def _str_to_bool(value):
    """Convert a command-line token such as 'True', 'false', 'yes' or '0' to a bool.

    argparse's ``type=bool`` calls ``bool(<string>)``, which is True for any
    non-empty string (including 'False'), so flags that take an explicit
    boolean value need this converter instead.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected, got {!r}'.format(value))


def parse_args():
    """Define the command-line interface, parse ``sys.argv`` and post-process it.

    Returns:
        The ``argparse.Namespace`` returned by ``check_args``, which adds a
        ``channels`` attribute and converts ``cuda`` to a bool.

    argparse itself prints usage and exits when the command line is invalid
    (unknown option, value outside ``choices``, unparsable type).
    """
    parser = argparse.ArgumentParser(description="Pytorch implementation of GAN models.")

    # --- distributed setup ---
    parser.add_argument('--nodes', type=int, default=8, help='Number of nodes in distributed training')
    parser.add_argument('--n_gpu', type=int, default=8, help='Number of gpus in distributed training')
    # NOTE(review): the numeric choices ('1'..'5') look like aliases of the named
    # ones ('sep', 'full', 'exp', 'ring', 'dense') — confirm in the topology code.
    parser.add_argument('--topo', type=str, default='exp',
                        choices=['1', 'sep', '2', 'full', '3', 'exp', '4', 'ring', '5', 'dense'],
                        help='Topology used in distributed training')
    parser.add_argument('--opt_comm', action='store_true', help='If communicate optimizer parameters')

    # --- data ---
    parser.add_argument('--dataroot', default='dataset', help='path to dataset')
    parser.add_argument('--dataset', type=str, default='cifar',
                        choices=['mnist', 'fashion-mnist', 'cifar', 'stl10'],
                        help='The name of dataset')
    # type=bool would treat the string 'False' as True; parse the text explicitly.
    parser.add_argument('--download', type=_str_to_bool, default=False)

    # --- training hyper-parameters ---
    parser.add_argument('--epochs', type=int, default=50, help='The number of epochs to run')
    parser.add_argument('--batch_size', type=int, default=512, help='The size of batch')
    parser.add_argument('--z_dim', type=int, default=50, help='latent variable dimension')
    # Kept as a string on purpose: check_args converts exactly 'True' to a bool.
    parser.add_argument('--cuda', type=str, default='True', help='Availability of cuda')
    parser.add_argument('--optim', type=str, default='adam', help='optimizer to use')
    parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
    parser.add_argument('--critic_iter', type=int, default=5, help='number of critic iteration')
    parser.add_argument('--alpha', type=float, default=0.6, help='parameter in TiAda')
    parser.add_argument('--beta', type=float, default=0.4, help='parameter in TiAda')

    parser.add_argument('--generator_iters', type=int, default=40000,
                        help='The number of iterations for generator in WGAN model.')

    # --- process-group initialisation ---
    parser.add_argument('--dist-url', default='tcp://127.0.0.1:55004', type=str,
                        help='url used to set up distributed training')
    parser.add_argument('--dist-backend', default='gloo', type=str,
                        help='distributed backend')

    return check_args(parser.parse_args())


# Checking arguments
def check_args(args):
    """Validate parsed arguments and derive settings that depend on them.

    Args:
        args: Namespace with at least ``epochs``, ``batch_size``, ``dataset``
            and ``cuda`` attributes (as produced by ``parse_args``).

    Returns:
        The same namespace, mutated in place. ``channels`` is set to 3 for
        'cifar'/'stl10' and 1 for any other dataset. ``cuda`` is replaced by
        ``True`` only when it was the exact string 'True', otherwise ``False``.

    Raises:
        ValueError: if ``epochs`` or ``batch_size`` is less than one.
    """
    # Explicit checks rather than ``assert`` inside a bare ``except``. The old
    # form only printed a message and continued with invalid values, and
    # asserts are removed entirely when Python runs with -O.
    if args.epochs < 1:
        raise ValueError('Number of epochs must be larger than or equal to one')
    if args.batch_size < 1:
        raise ValueError('Batch size must be larger than or equal to one')

    # 'cifar' and 'stl10' get 3 input channels; every other dataset gets 1.
    args.channels = 3 if args.dataset in ('cifar', 'stl10') else 1
    # --cuda is declared as a string option; only the exact text 'True' enables it.
    args.cuda = args.cuda == 'True'
    return args
